#=
Code that holds all important model functions and estimation procedures
=#

# Struct holding the model primitives that are NOT estimated.
#
# Most location-level vectors are columns of `state_chars` (one row per location), so a
# caller normally only supplies `state_chars` plus the other matrices read from file.
# Vectors of length 3 are indexed by race R (1 = baseline group, 2 and 3 correspond to the
# `_b` / `_h` estimands); vectors of length 2 are indexed by schooling S (1 = HS, 2 = college).
@with_kw struct Primitives
    β::Float64 = 0.479 #discount rate: 0.96^18
    μ_ε3::Array{Float64,1} = [0.02, 0.07] #mean of later-life HC shock, by schooling
    σ_ε3::Array{Float64,1} = [0.24, 0.24] #SD of later-life HC shock, by schooling

    #racial skill price/fertility/marriage modifiers
    w_mods::Array{Float64, 1} = [0.0, -0.145, -0.087]
    marr_mod_hs::Array{Float64, 1} = [0.0, -0.632, -0.122]
    marr_mod_coll::Array{Float64, 1} = [0.0, -0.583, -0.179]
    fert_mod_coll_marr::Array{Float64, 1} = [0.0, -0.073, 0.007]
    fert_mod_coll_nmarr::Array{Float64, 1} = [0.0, 0.432, 0.252]
    fert_mod_hs_marr::Array{Float64, 1} = [0.0, -0.049, 0.176]
    fert_mod_hs_nmarr::Array{Float64, 1} = [0.0, 0.352, 0.315]

    #stuff read in from CSV files
    #the placeholder needs all 12 columns that are sliced below, otherwise `Primitives()`
    #with default arguments throws a BoundsError
    state_chars::Array{Float64,2} = zeros(1,12) #state characteristics read in from CSV file
    fips::Array{Float64,1} = state_chars[:,1] #fips codes
    skill_prices::Array{Float64,2} = state_chars[:,2:3] #relative skill prices (column per schooling level)
    govt_inv::Array{Float64,1} = state_chars[:,4] #normalized government investment in children
    s_t_ratios::Array{Float64,1} = state_chars[:,5] #student-teacher ratios
    prices::Array{Float64,1} = state_chars[:,6] #consumption prices
    pops::Array{Float64,1} = state_chars[:,7] #state populations
    tuitions::Array{Float64,1} = state_chars[:,8] #public tuition costs
    taxes::Array{Float64, 1} = state_chars[:,9] #tax rates (enter consumption as 1 - taxes[l])
    divs::Array{Float64,1} = state_chars[:,10] #state census divisions
    amenity::Array{Float64,1} = state_chars[:,11] #baseline amenity measure
    amen_grid::Array{Float64, 2} = zeros(50, 7) #grid of amenities for robustness exercises
    regions::Array{Int64, 1} = state_chars[:, 12] #region index of each location (must be integer-valued; used to index 1:4)

    #financial aid schedule
    aid_sched::Array{Float64, 2} = zeros(1, 1)

    #other stuff
    prox_mat::Array{Float64,2} = zeros(1,1) #location distance matrix
    #marriage/fertility probability coefficients by location. The period-2 objective reads
    #columns 2-5 as a cubic in human capital and column 6 as the value used when HC > 4;
    #the marriage tables additionally need column 7 (migration shifter).
    kid_marr_hs::Array{Float64, 2} = zeros(50,7) #kid marriage probabilities, HS
    kid_fert_marr_hs::Array{Float64,2} = zeros(50,6) #kid fertility probabilities, married, HS
    kid_fert_nonmarr_hs::Array{Float64,2} = zeros(50,6) #kid fertility probabilities, non-married, HS
    kid_marr_coll::Array{Float64, 2} = zeros(50,7) #kid marriage probabilities, college
    kid_fert_marr_coll::Array{Float64,2} = zeros(50,6) #kid fertility probabilities, married, college
    kid_fert_nonmarr_coll::Array{Float64,2} = zeros(50,6) #kid fertility probabilities, non-married, college
    mig_mat::Array{Float64,2} = zeros(50,50) #migration rates observed in data

    #grid points
    nh::Int64 = 25 #number of human capital grid points
    nl::Int64 = length(fips) #number of locations
    na::Int64 = 5 #number of ability states -- keeping it coarse for now
    ne::Int64 = 5 #number of possible earnings shocks
    nf::Int64 = 2 #fertility states
    nm::Int64 = 2 #marriage states
    nS::Int64 = 2 #college states
    nR::Int64 = 3 #races

    #parent HC distribution (raw mean and SD in levels)
    μ_h_raw::Float64 = 0.9021
    σ_h_raw::Float64 = 0.634

    #parent distribution: log-normal parameters implied by the raw moments.
    #σ_h is used downstream as a standard deviation (e.g. exp(μ_h + σ_h^2/2), σ_a/σ_h), so it
    #must be the square root of the log-variance log(1 + σ_raw^2/μ_raw^2).
    μ_h::Float64 = log(μ_h_raw^2 / sqrt(μ_h_raw^2 + σ_h_raw^2))
    σ_h::Float64 = sqrt(log(1 + σ_h_raw^2/μ_h_raw^2))
    hc_norm::Float64 = 0.891 #amount by which to normalize human capital when computing moving cost adjustments
    tuition_share::Float64 = 0.5 #share of tuition covered by parent. Start at 0.5 for now.
end

# Struct holding the model parameters to be estimated.
#
# Every field defaults to an entry of `guess`, so `Estimands(guess = v)` maps a flat
# parameter vector into named parameters. `guess` must have at least 29 entries because
# indices up to 29 are read below; the default is sized accordingly so that `Estimands()`
# does not throw a BoundsError.
@with_kw mutable struct Estimands
    guess::Vector{Float64} = zeros(29)

    θ::Float64 = guess[1] #parental altruism
    ρ_ha::Float64 = guess[2] #persistence of learning abilities
    μ_a::Float64 = guess[3] #mean of learning abilities
    σ_a::Float64 = guess[4] #SD of learning abilities
    ξ::Float64 = guess[5] #child-to-adult HC anchor
    ϕ::Float64 = guess[6] #time share of child HC production
    κ::Float64 = guess[7] #Ben-Porath HC accumulation parameter
    σ_ε2::Float64 = guess[8] #young adult HC shock SD
    α_1::Float64 = guess[9] #amenity
    Δ_1::Float64 = guess[10] #moving fixed cost
    Δ_2::Float64 = guess[11] #HC moving distance cost modifier
    Δ_3::Float64 = guess[12] #proximity moving cost reduction
    Δ_4::Float64 = guess[13] #college moving cost effect
    Δ_5::Float64 = guess[20] #population component
    σ_ζ::Float64 = guess[14] #T1EV shock scale parameter

    #college parameters
    η_1::Float64 = guess[15] #fixed component
    η_2::Float64 = guess[16] #ability component
    η_3::Float64 = guess[17] #parent prestige effect
    σ_η::Float64 = guess[18] #college preference shock spread

    #dynamic moving parameters
    Δ_mod::Float64 = guess[19] #period-3 moving fixed cost adjustment

    #other amenities -- currently switched off (fixed at zero rather than read from `guess`)
    α_2::Float64 = 0.0
    α_3::Float64 = 0.0

    #race parameters (`_b` and `_h` are selected for R = 2 and R = 3 respectively)
    η_1_b::Float64 = guess[21]
    η_1_h::Float64 = guess[22]

    Δ_mod_b::Float64 = guess[23]
    Δ_mod_h::Float64 = guess[24]

    #race-specific ability means are tied to the common mean
    μ_a_b::Float64 = guess[3]
    μ_a_h::Float64 = guess[3]

    ξ_b::Float64 = guess[25]
    ξ_h::Float64 = guess[26]

    #region-specific ability means: regions 2, 3, 4
    μ_a_r_1::Float64 = guess[3]  #omitted case, equal to the common mean
    μ_a_r_2::Float64 = guess[27]
    μ_a_r_3::Float64 = guess[28]
    μ_a_r_4::Float64 = guess[29]
end

"""
    Grids

Mutable container for the discretised state grids used when solving the model.
All fields are `Vector{Float64}`; inputs of another element type (e.g. an integer
location list) are converted by the default constructor.
"""
mutable struct Grids
    # human capital grid
    h_grid::Vector{Float64}
    # list of locations
    loc_grid::Vector{Float64}
    # earnings shock grid, period 2
    ε2_grid::Vector{Float64}
    # earnings shock grid, period 3, high-school
    ε3_grid_hs::Vector{Float64}
    # earnings shock grid, period 3, college
    ε3_grid_coll::Vector{Float64}
    # learning ability grid
    a_grid::Vector{Float64}
end

"""
    Results

Mutable container for the computed value and policy functions. All arrays are
`SharedArray`s so that distributed workers can write into them.
"""
mutable struct Results
    # period-2 value function (after the move)
    val_func_2::SharedArray{Float64,6}
    # period-3 value function
    val_func_3::SharedArray{Float64,6}
    # policy function for n
    n_func::SharedArray{Float64,6}
    # policy function for x
    x_func::SharedArray{Float64,5}
    # policy function for t
    t_func::SharedArray{Float64,5}
end

"""
    Initialize(package, package_moments; p2 = 0)

Build the model primitives and allocate the (zero-filled) value/policy function arrays.

# Arguments
- `package`: matrices read from file. Entries used here: `[1]` (or `[2]` when `p2 == 1`)
  state characteristics, `[3]` proximity matrix, `[5:10]` marriage/fertility tables,
  `[12]` financial-aid schedule, `[15]` (or `[16]` when `p2 == 1`) amenity grid.
- `package_moments`: data moments; only entry `[2]`, the migration matrix, is used.

# Keywords
- `p2`: set to `1` to use the alternate ("2010") state characteristics and amenity grid.

# Returns
A tuple `(prim, res)` of `Primitives` and `Results`.
"""
function Initialize(package::Vector{Matrix{Any}}, package_moments::Vector{Matrix{Float64}}; p2::Int64 = 0)

    #unpack migration matrix
    mig_mat = package_moments[2]

    #pick the inputs once instead of constructing Primitives twice when p2 == 1
    chars_idx, amen_idx = p2 == 1 ? (2, 16) : (1, 15)

    #initialize primitives
    prim = Primitives(state_chars = package[chars_idx], prox_mat = package[3],
    kid_marr_hs = package[5], kid_fert_marr_hs = package[6], kid_fert_nonmarr_hs = package[7],
    kid_marr_coll = package[8], kid_fert_marr_coll = package[9], kid_fert_nonmarr_coll = package[10],
    mig_mat = mig_mat, aid_sched = package[12], amen_grid = package[amen_idx])

    #initialize results. The arrays are read before every entry has been written (the first
    #Bellman_3 pass reads val_func_2, and the convergence check compares all entries), so
    #fill them explicitly rather than relying on the contents of fresh shared memory.
    @unpack nh, nl, na, nf, nS, nR = prim
    val_func_2 = fill!(SharedArray{Float64}(na, nh, nl, nl, nS, nR), 0.0) #same shape Bellman_2 allocates
    n_func = fill!(SharedArray{Float64}(na, nh, nl, nl, nS, nR), 0.0)

    val_func_3 = fill!(SharedArray{Float64}(nh, nl, na, nf, nS, nR), 0.0)
    x_func = fill!(SharedArray{Float64}(nh, nl, na, nS, nR), 0.0)
    t_func = fill!(SharedArray{Float64}(nh, nl, na, nS, nR), 0.0)
    res = Results(val_func_2, val_func_3, n_func, x_func, t_func)
    return prim, res #return deliverables
end

"""
    Initialize_grids(prim, guess)

Map the parameter vector `guess` into an `Estimands` object and build the computing grids.

# Returns
A tuple `(grids, est)` of `Grids` and `Estimands`.
"""
function Initialize_grids(prim::Primitives, guess::Array{Float64,1})

    est = Estimands(guess = guess)
    n_shocks = prim.ne

    # location indices 1, ..., nl
    locations = collect(1:prim.nl)

    # discretized log-normal shock and ability grids
    shocks_p2 = discretize_lognormal(-(est.σ_ε2^2) / 2, est.σ_ε2, n_shocks)
    shocks_p3_hs = discretize_lognormal(prim.μ_ε3[1], prim.σ_ε3[1], n_shocks)
    shocks_p3_coll = discretize_lognormal(prim.μ_ε3[2], prim.σ_ε3[2], n_shocks)
    abilities = discretize_lognormal(est.μ_a, est.σ_a, prim.na)

    # evenly spaced human capital grid. The lower bound is kept away from zero so nothing
    # gets rounded down to -Inf; the upper bound is roughly the 99th income percentile for
    # the adult stage.
    h_lo, h_hi = 0.1, 6.0
    hc_points = collect(range(h_lo, stop = h_hi, length = prim.nh))

    grids = Grids(hc_points, locations, shocks_p2, shocks_p3_hs, shocks_p3_coll, abilities)
    return grids, est
end

"""
    V_iterate(prim, est, res, grids; tol = 1e-3, l_select = 0, cfact = 0, aopt = 0)

Iterate on the period-3 and period-2 Bellman operators until the sup-norm change in the
period-2 value function is at most `tol`, or until 10 iterations have been performed.
`res` is updated in place (`val_func_2`, `val_func_3` and the policy functions).

# Keywords
- `tol`: convergence tolerance on the period-2 value function.
- `l_select`, `cfact`: location and size of the retention-policy counterfactual; they are
  forwarded to `Bellman_2` only when `l_select > 0`, any other value runs the baseline.
- `aopt`: amenity specification forwarded to both Bellman operators.

Returns `nothing`. A warning is emitted if the iteration cap is reached (or the error is
not a number) without meeting `tol`.
"""
function V_iterate(prim::Primitives, est::Estimands, res::Results, grids::Grids; tol::Float64 = 1e-3, l_select::Int64 = 0, cfact::Int64 = 0, aopt::Int64 = 0)
    max_iter = 10 #iteration cap
    err = 100.0 #big starting error
    n_iter = 0

    while err > tol && n_iter < max_iter #begin main value function iteration loop
        Bellman_3(prim, est, res, grids; aopt = aopt) #update guess of v3 and policy function

        #get new guess of v2 and policy functions; the counterfactual only applies for a valid location
        v2_next = if l_select > 0
            Bellman_2(prim, est, res, grids; l_select = l_select, cfact = cfact, aopt = aopt)
        else
            Bellman_2(prim, est, res, grids; aopt = aopt)
        end

        err = maximum(abs.(v2_next .- res.val_func_2)) #update error
        res.val_func_2 = v2_next #update guess of period 2 value function
        n_iter += 1
    end

    if err <= tol
        println("Value functions converged in ", n_iter, " iterations")
    else
        #also reached when err is NaN, which would otherwise end the loop silently
        @warn "Value function iteration stopped without converging" iterations = n_iter err tol
    end
    return nothing
end

"""
    Bellman_3(prim, est, res, grids; aopt = 0)

Period-3 Bellman operator. Updates `res.val_func_3`, `res.x_func` and `res.t_func` in place.

For parents with kids (4th index of `val_func_3` equal to 1) the pair `(x, t)` is chosen by
grid search over `0:0.025:0.5` on each dimension, maximising `Obj_func_v3`. For parents
without kids (4th index equal to 2) the value is current utility plus the discounted
expected utility over the period-3 shock grid.

# Keywords
- `aopt`: amenity specification. `0` uses `prim.amenity` only; `1`-`3` add the column pair
  `(2*(aopt-1)+1, 2*(aopt-1)+2)` of `prim.amen_grid` as second/third amenities; `4` replaces
  the baseline amenity with column 7 of `prim.amen_grid`.

Returns `nothing`.
"""
function Bellman_3(prim::Primitives, est::Estimands, res::Results, grids::Grids; aopt::Int64 = 0)
    @unpack na, nh, nl, prices, skill_prices, amenity, β, ne, nR, nS, taxes, amen_grid, w_mods = prim #unpack some primitives
    @unpack h_grid, ε3_grid_hs, ε3_grid_coll = grids
    @unpack α_1, α_2, α_3 = est
    ε3_grid = [ε3_grid_hs, ε3_grid_coll]

    #figure out what our amenities are (sized by the number of locations, not a fixed 50)
    amen_1 = amenity
    amen_2 = zeros(nl)
    amen_3 = zeros(nl)
    if aopt == 4 #switch to the alternative baseline amenity
        amen_1 = amen_grid[:, 7]
    elseif aopt > 0 && aopt < 4
        a_ind = 2*(aopt-1) + 1
        amen_2 = amen_grid[:, a_ind]
        amen_3 = amen_grid[:, a_ind+1]
    end

    #amenity value from the current and the future period, by location
    amen_val = [(1+β) * (α_1 * amen_1[l] + α_2 * amen_2[l] + α_3 * amen_3[l]) for l in 1:nl]

    x_grid, t_grid = collect(0:0.025:0.5), collect(0:0.025:0.5)
    nx, nt = length(x_grid), length(t_grid)

    #fill in value function for parents with kids
    @sync @distributed for l = 1:nl #outer loop to distribute over
        for ap = 1:na, S = 1:nS, R = 1:nR
            x_lower = 1 #initial lower bound for the x search
            for h = 1:nh
                candidate_max = -1e10 #bad candidate max value
                for x = x_lower:nx, t = 1:nt #grid search
                    val = Obj_func_v3(x_grid[x], t_grid[t], prim, est, res, grids, h, l, ap, S, R) #value
                    if val > candidate_max #new max!
                        candidate_max = val #update candidate max
                        res.x_func[h, l, ap, S, R] = x_grid[x] #update x policy function
                        res.t_func[h, l, ap, S, R] = t_grid[t] #update t policy function
                        x_lower = x #parent monetary investment weakly increasing in HC -- saves time.
                    end
                end
                res.val_func_3[h, l, ap, 1, S, R] = candidate_max + amen_val[l] #update value function
            end
        end
    end

    #now for parents without kids -- much simpler; the value does not depend on ap
    for h = 1:nh, l = 1:nl, S = 1:nS, R = 1:nR

        #NOTE(review): exp(log(p - w)) equals p - w, so a negative modifier raises the price;
        #confirm whether exp(log(p) + w) was intended. Left unchanged here.
        skill_price = exp(log(skill_prices[l, S] - w_mods[R]))

        c = (1-taxes[l]) * (h_grid[h]*skill_price)/prices[l]
        val = utility(c)

        #possible realizations of next-period human capital and implied consumption
        h_vals = ε3_grid[S].*h_grid[h]
        c_vals = h_vals.*(skill_price/prices[l]) .* (1-taxes[l])
        for i = 1:ne
            val += β * utility(c_vals[i]) * (1.0/ne) #add to expectation, using the fact that all shocks are equally likely
        end
        val += amen_val[l] #add amenities from current, future period

        for ap = 1:na
            res.val_func_3[h, l, ap, 2, S, R] = val #update
        end
    end
    return nothing
end

"""
    Obj_func_v3(x, t, prim, est, res, grids, h, l, ap, S, R)

Period-3 objective of a parent with a child, evaluated at monetary investment `x` and time
investment `t`.

# Arguments
- `h`, `l`, `ap`, `S`, `R`: indices of the parent's human capital grid point, location,
  child ability grid point, schooling (1 = HS, 2 = college) and race.

# Returns
Current utility of consumption (scaled by `1 + θ`, with consumption divided by 1.5) plus the
discounted expectation, over the period-3 shock grid, of next-period utility including the
child's location/college choice value read from `res.val_func_2`. Returns `-Inf` when the
implied current consumption is not strictly positive. Amenity terms are NOT included; the
caller adds them.
"""
function Obj_func_v3(x::Float64, t::Float64, prim::Primitives, est::Estimands, res::Results, grids::Grids, h::Int64, l::Int64, ap::Int64, S::Int64, R::Int64)

    @unpack θ, ξ, ϕ, η_1, η_2, η_3, Δ_1, Δ_2, Δ_3, Δ_4, Δ_5, σ_ζ, σ_η = est #unpacking
    @unpack η_1_b, η_1_h, Δ_mod_b, Δ_mod_h, ξ_b, ξ_h = est
    @unpack val_func_2 = res
    @unpack ε3_grid_hs, ε3_grid_coll, h_grid, a_grid = grids
    @unpack β, skill_prices, prices, ne, govt_inv, s_t_ratios, amenity, μ_h, σ_h, taxes, tuitions, nl, prox_mat, aid_sched, w_mods = prim

    ε3_grid = [ε3_grid_hs, ε3_grid_coll]

    #race-specific parameters
    Δ_mod_R = (0.0, Δ_mod_b, Δ_mod_h)[R]
    η_R = (η_1, η_1_b, η_1_h)[R]
    ξ_R = (ξ, ξ_b, ξ_h)[R]
    #NOTE(review): exp(log(p - w)) equals p - w, so a negative modifier raises the price;
    #confirm whether exp(log(p) + w) was intended. Left unchanged here.
    skill_price = exp(log(skill_prices[l, S] - w_mods[R])) #skill price to use

    #consumption given choice of x and t; infeasible choices are worth -Inf, so skip the
    #(expensive) continuation value for them
    c = (1-taxes[l]) * (1-t)*(h_grid[h]*skill_price)/prices[l] - x
    c > 0 || return -Inf
    val = (1.0 + θ) * utility(c/1.5) #consumption equivalence scale of 1.5

    #compute kid's human capital and its fractional position on the HC grid
    h_child = ξ_R * a_grid[ap] * ((x + (1-ϕ)*govt_inv[l]/prices[l])^(1-ϕ)) * (t + ϕ * govt_inv[l]/(prices[l] * s_t_ratios[l] * exp(μ_h + (σ_h^2)/2)))^ϕ
    h2 = get_index(h_child, h_grid)
    h2_f, h2_c = Int64(floor(h2)), Int64(ceil(h2))
    weight = h2 - h2_f #linear interpolation weight on the upper grid point

    #kid's HS and college values via the usual T1EV log-sum formula over destinations
    kid_val_hs_sum = 0.0 #Float64 accumulators keep the loop type-stable
    kid_val_coll_sum = 0.0
    for lp = 1:nl
        #moving cost; the 4th index of val_func_2 is 1 for stayers and 2 for movers
        movecost_hs, movecost_coll = 0.0, 0.0
        mig = 1
        if lp != l
            movecost_hs = Δ_1 - Δ_2 * (h_child) - Δ_3 * prox_mat[l, lp] - Δ_5*amenity[lp] + Δ_mod_R
            movecost_coll = Δ_1 - Δ_2 * (h_child) - Δ_3 * prox_mat[l, lp] - Δ_4 - Δ_5*amenity[lp] + Δ_mod_R
            mig = 2
        end

        v_hs = val_func_2[ap, h2_f, lp, mig, 1, R] * (1-weight) + val_func_2[ap, h2_c, lp, mig, 1, R] * weight
        v_coll = val_func_2[ap, h2_f, lp, mig, 2, R] * (1-weight) + val_func_2[ap, h2_c, lp, mig, 2, R] * weight

        kid_val_hs_sum += exp((1/σ_ζ) * v_hs - (1/σ_ζ)*movecost_hs)
        kid_val_coll_sum += exp((1/σ_ζ) * v_coll - (1/σ_ζ)*movecost_coll)
    end
    kid_val_hs = Base.MathConstants.eulergamma*σ_ζ + σ_ζ * log(kid_val_hs_sum)
    kid_val_coll = Base.MathConstants.eulergamma*σ_ζ + σ_ζ * log(kid_val_coll_sum)
    dist_η = Normal(0, σ_η) #distribution of college preference shock

    #possible realizations of next-period human capital
    h_vals = ε3_grid[S].*h_grid[h]

    for i = 1:ne
        #NOTE(review): income passed to Grants uses the unadjusted skill price; confirm intended
        pI = h_vals[i] * skill_prices[l, S]

        #consumption without and with college (tuition net of grants)
        c_val_hs = h_vals[i]*(skill_price/prices[l]) * (1-taxes[l])
        c_val_coll = h_vals[i]*(skill_price/prices[l]) * (1-taxes[l]) - (tuitions[l]-Grants(pI, ap, aid_sched))/prices[l]
        u_hs, u_coll = utility(c_val_hs), utility(c_val_coll)

        #college is chosen when the preference shock exceeds this threshold
        threshold = u_hs + (1+θ)*kid_val_hs - (u_coll + η_3 * (S-1) + (1+θ)*(η_R + η_2 * a_grid[ap]  + kid_val_coll))
        prob_coll = 1 - cdf(dist_η, threshold) #probability family decides on college
        val += β * prob_coll * (u_coll + η_3*(S-1) + θ * (kid_val_coll + η_R + η_2*a_grid[ap] + Mills(threshold, 0.0, σ_η))) * (1/ne) #parent private utility
        val += β * (1-prob_coll) * (u_hs + θ * kid_val_hs) * (1/ne) #parent private utility
    end
    return val
end

"""
    Bellman_2(prim, est, res, grids; l_select = 0, cfact = 0, aopt = 0)

Period-2 Bellman operator. For every state, maximise `Obj_func_v2` over `n` in `[0, 1]`
with `Optim.optimize`, store the maximiser in `res.n_func`, and return the new guess of the
period-2 value function (maximised objective plus the location's amenity value).

# Keywords
- `l_select`, `cfact`: retention-policy counterfactual; `cfact` is converted to a
  consumption boost that `Obj_func_v2` applies in location `l_select`.
- `aopt`: amenity specification. `0` uses `prim.amenity` only; `1`-`3` add a column pair of
  `prim.amen_grid` as second/third amenities; `4` replaces the baseline amenity with
  column 7 of `prim.amen_grid`.

# Returns
A `SharedArray{Float64,6}` of size `(na, nh, nl, nl, nS, nR)`. Only entries whose 4th index
is 1 (stayer) or 2 (mover) are computed; the remaining entries are zero.
"""
function Bellman_2(prim::Primitives, est::Estimands, res::Results, grids::Grids; l_select::Int64 = 0, cfact::Int64 = 0, aopt::Int64 = 0)
    @unpack nl, na, nh, amenity, nS, amen_grid, nR = prim
    @unpack α_1, α_2, α_3 = est
    c_boost = (cfact*10000)/(18*47961) #consumption boost from retention policy counterfactual

    #new guess for the period 2 value function. Zero-fill it: only 4th index 1:2 is written
    #below, but the caller's convergence check takes a maximum over every entry.
    v2_next = SharedArray{Float64}(na, nh, nl, nl, nS, nR)
    fill!(v2_next, 0.0)

    #figure out what our amenities are (sized by the number of locations, not a fixed 50)
    amen_1 = amenity
    amen_2 = zeros(nl)
    amen_3 = zeros(nl)
    if aopt == 4 #switch to the alternative baseline amenity
        amen_1 = amen_grid[:, 7]
    elseif aopt > 0 && aopt < 4
        a_ind = 2*(aopt-1) + 1
        amen_2 = amen_grid[:, a_ind]
        amen_3 = amen_grid[:, a_ind+1]
    end
    amen_val = [α_1 * amen_1[l] + α_2 * amen_2[l] + α_3 * amen_3[l] for l in 1:nl] #current-period amenity value

    #iterate over state space
    @sync @distributed for l = 1:nl
        for a = 1:na, h = 1:nh, S = 1:nS, mig = 1:2, R = 1:nR
            #solve for optimal n (minimise the negative objective)
            obj = n -> -Obj_func_v2(n, prim, est, res, grids, a, h, l, mig, S, R; l_select = l_select, c_boost = c_boost)
            n_choice = Optim.optimize(obj, 0.0, 1.0)
            v2_next[a, h, l, mig, S, R] = -n_choice.minimum + amen_val[l] #update value function next guess
            res.n_func[a, h, l, mig, S, R] = n_choice.minimizer #update policy function
        end
    end
    return v2_next #return new guess of v2
end

"""
    Obj_func_v2(n, prim, est, res, grids, a, h, l, mig, S, R; report = 0, l_select = 0, c_boost = 0.0)

Period-2 objective evaluated at human-capital investment time `n`.

# Arguments
- `a`, `h`, `l`, `S`, `R`: indices of own ability, human capital grid point, location,
  schooling (1 = HS, 2 = college) and race.
- `mig`: 1 for a stayer, 2 for a mover; enters the marriage probability only.

# Keywords
- `report`: accepted for interface compatibility; not used in this function.
- `l_select`, `c_boost`: consumption is raised by `c_boost` when `l == l_select` and `S == 2`.

# Returns
Current utility of consumption plus `β` times the T1EV expected value over next-period
locations. For each destination the continuation value averages `res.val_func_3` over the
period-2 HC shock, the child's ability (conditional on realised HC and the destination's
region), marriage and fertility; married states are scaled by `1 + θ`.
"""
function Obj_func_v2(n::Float64, prim::Primitives, est::Estimands, res::Results, grids::Grids, a::Int64, h::Int64, l::Int64, mig::Int64, S::Int64, R::Int64;
    report::Int64=0, l_select::Int64 = 0, c_boost::Float64 = 0.0)
    @unpack θ, κ, σ_ζ, μ_a, σ_a, ρ_ha, Δ_2, Δ_3, Δ_4, Δ_5, Δ_mod, Δ_mod_b, Δ_mod_h = est #unpacking
    @unpack μ_a_r_2, μ_a_r_3, μ_a_r_4 = est
    @unpack a_grid, ε2_grid, h_grid = grids
    @unpack β, skill_prices, prices, na, ne, nl, prox_mat, σ_h, μ_h, amenity, taxes, regions, w_mods = prim
    @unpack kid_marr_hs, kid_fert_marr_hs, kid_fert_nonmarr_hs, kid_marr_coll, kid_fert_marr_coll, kid_fert_nonmarr_coll = prim
    @unpack marr_mod_hs, marr_mod_coll, fert_mod_coll_marr, fert_mod_coll_nmarr, fert_mod_hs_marr, fert_mod_hs_nmarr = prim
    @unpack val_func_3 = res

    #race-specific moving cost shifter and region-specific ability means (region 1 uses μ_a)
    Δ_mod_R = (0.0, Δ_mod_b, Δ_mod_h)[R]
    μ_a_opts = (μ_a, μ_a_r_2, μ_a_r_3, μ_a_r_4)
    n_reg = length(μ_a_opts)
    #NOTE(review): exp(log(p - w)) equals p - w, so a negative modifier raises the price;
    #confirm whether exp(log(p) + w) was intended. Left unchanged here.
    skill_price = exp(log(skill_prices[l, S] - w_mods[R])) #skill price to use

    #race shifters and coefficient tables for the relevant schooling level
    marr_R_mod = (marr_mod_hs, marr_mod_coll)[S][R]
    fert_marr_R_mod = (fert_mod_hs_marr, fert_mod_coll_marr)[S][R]
    fert_nmarr_R_mod = (fert_mod_hs_nmarr, fert_mod_coll_nmarr)[S][R]
    kid_marr = (kid_marr_hs, kid_marr_coll)[S]
    kid_fert_marr = (kid_fert_marr_hs, kid_fert_marr_coll)[S]
    kid_fert_nonmarr = (kid_fert_nonmarr_hs, kid_fert_nonmarr_coll)[S]

    time_coll = (2/9) * (S-1) #time spent in college

    #consumption given choice of n
    c = (1-taxes[l]) * skill_price * h_grid[h] * (1.0-n - time_coll) / prices[l]

    #counterfactual adjustment for retention policy
    if l == l_select && S == 2
        c += c_boost
    end

    val = utility(c) #starting value

    #next-period HC realizations, marriage/fertility probabilities and the kid's ability
    #distribution -- these do not depend on the destination, so compute them once
    prob_marr, prob_fert_marr, prob_fert_nonmarr = zeros(ne), zeros(ne), zeros(ne)
    h3_vals = zeros(ne)
    h3_indices = zeros(ne)
    ap_probs = zeros(ne, na, n_reg)
    log_a_sd = sqrt((1-ρ_ha^2)*σ_a^2) #conditional SD of log ability
    for s = 1:ne
        h_val = ε2_grid[s]*(a_grid[a]*(n*h_grid[h])^κ + h_grid[h]) #stochastic HC realization
        h3_vals[s] = h_val
        h3_indices[s] = get_index(h_val, h_grid)

        #probabilities of marriage and kids: cubic in HC, replaced by column 6 above HC = 4
        if h_val > 4
            prob_marr[s] = kid_marr[l, 6]
            prob_fert_marr[s], prob_fert_nonmarr[s] = kid_fert_marr[l, 6], kid_fert_nonmarr[l, 6]
        else
            prob_marr[s] = 𝚽(kid_marr[l, 2] + h_val*kid_marr[l, 3] +  h_val^2*kid_marr[l, 4] +  h_val^3*kid_marr[l, 5] + kid_marr[l, 7] * (mig==2) + marr_R_mod) #includes migration shifter
            prob_fert_marr[s] = 𝚽(kid_fert_marr[l,2] + h_val*kid_fert_marr[l,3] +  h_val^2*kid_fert_marr[l,4] +  h_val^3*kid_fert_marr[l,5] + fert_marr_R_mod)
            prob_fert_nonmarr[s] = 𝚽(kid_fert_nonmarr[l,2] + h_val*kid_fert_nonmarr[l,3] +  h_val^2*kid_fert_nonmarr[l,4] +  h_val^3*kid_fert_nonmarr[l,5] + fert_nmarr_R_mod)
        end

        #kid's ability probabilities conditional on parent HC, one set per region
        for reg = 1:n_reg
            log_a_mean = μ_a_opts[reg] + ρ_ha*(σ_a/σ_h)*(log(h_val) - μ_h) #mean of log of ability
            dist = LogNormal(log_a_mean, log_a_sd)
            for ap = 1:na
                ap_probs[s, ap, reg] = pdf(dist, a_grid[ap])
            end
            ap_probs[s, :, reg] ./= sum(ap_probs[s, :, reg]) #normalize over the ability grid
        end
    end

    #expected next-period HC. Previously this was never accumulated (always 0), which
    #silently switched off the Δ_2 term of the moving cost below.
    exp_h = sum(h3_vals) / ne

    #loop over possible next locations
    val_sum = 0.0
    for lp = 1:nl
        val_lp = 0.0
        reg = regions[lp]

        #factor in moving cost
        movecost = 0.0
        if lp != l
            movecost = Δ_mod - Δ_2 * (exp_h) - Δ_3 * prox_mat[l, lp] - Δ_4 * (S-1) - Δ_5*amenity[lp] + Δ_mod_R  #now adjusted for larger later-period moving cost
        end

        for s = 1:ne
            h_ind = h3_indices[s]
            h_ind_f, h_ind_c = Int64(floor(h_ind)), Int64(ceil(h_ind))
            weight = h_ind - h_ind_f #linear interpolation weight on the upper grid point
            p_marr, p_fert_m, p_fert_nm = prob_marr[s], prob_fert_marr[s], prob_fert_nonmarr[s]

            for ap = 1:na #loop over kid ability states
                #interpolated period-3 values with kids (index 1) and without kids (index 2)
                v_kid = val_func_3[h_ind_f, lp, ap, 1, S, R] * (1-weight) + val_func_3[h_ind_c, lp, ap, 1, S, R] * weight
                v_nokid = val_func_3[h_ind_f, lp, ap, 2, S, R] * (1-weight) + val_func_3[h_ind_c, lp, ap, 2, S, R] * weight

                v_single = (1-p_marr) * (p_fert_nm * v_kid + (1-p_fert_nm) * v_nokid)
                v_married = p_marr * (1+θ) * (p_fert_m * v_kid + (1-p_fert_m) * v_nokid)

                val_lp += (1/σ_ζ) * (1/ne) * ap_probs[s, ap, reg] * (v_single + v_married)
            end
        end

        val_lp -= (1/σ_ζ)*movecost
        val_sum += exp(val_lp)
    end

    #T1EV expected maximum over destinations
    exp_v2 = Base.MathConstants.eulergamma*σ_ζ + σ_ζ * log(val_sum)
    val += β*exp_v2 #add
    return val
end

##################
